# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List, Tuple

import torch
from tabulate import tabulate
from tqdm import tqdm

# Assuming these imports based on the kernel location
from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.blockwise_fp8_training.kernels import (
    torch_blockwise_scale_act_quant_lhs,
    triton_fp8_blockwise_act_quant_lhs,
)

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: Tuple[int, int]  # (M, K)
    block_size: int


@dataclass(frozen=True)
class ExperimentResult:
    # time
    torch_us: float
    triton_us: float
    # mem bw
    torch_gbps: float
    triton_gbps: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    """
    Test configurations for typical transformer activation shapes.
    Format: (batch_size * seq_len, hidden_dim)
    """
    # Llama-style shapes: various batch*seq_len sizes with typical hidden dims
    input_shapes = [
        (512, 4096),
        (1024, 4096),
        (2048, 4096),
        (4096, 4096),
        (8192, 4096),
        (16384, 4096),
        (32768, 4096),
        (65536, 4096),
        (131_072, 4096),
    ]

    configs = []
    block_sizes = [128]  # Standard block size for FP8

    for shape in input_shapes:
        for block_size in block_sizes:
            configs.append(
                ExperimentConfig(
                    input_shape=shape,
                    block_size=block_size,
                )
            )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    M, K = config.input_shape
    block_size = config.block_size

    def verify_outputs(
        y_torch: torch.Tensor,
        s_torch: torch.Tensor,
        y_triton: torch.Tensor,
        s_triton: torch.Tensor,
        rtol: float = 1e-2,
        atol: float = 1e-2,
    ):
        """Verify that Triton and torch implementations produce similar results."""

        # Convert FP8 back to float for comparison
        y_torch_float = y_torch.to(torch.float32)
        y_triton_float = y_triton.to(torch.float32)

        assert y_torch.shape == y_triton.shape, (
            f"Output shape mismatch: torch {y_torch.shape} vs triton {y_triton.shape}"
        )
        assert y_torch.stride() == y_triton.stride(), (
            f"Output stride mismatch: torch {y_torch.stride()} vs triton {y_triton.stride()}"
        )

        assert s_torch.shape == s_triton.shape, (
            f"Scale shape mismatch: torch {s_torch.shape} vs triton {s_triton.shape}"
        )
        assert s_torch.stride() == s_triton.stride(), (
            f"Scale stride mismatch: torch {s_torch.stride()} vs triton {s_triton.stride()}"
        )

        # Check quantized values are close

        torch.testing.assert_close(
            y_torch_float,
            y_triton_float,
            rtol=rtol,
            atol=atol,
            msg="Quantized values differ between torch and Triton implementations",
        )

        torch.testing.assert_close(
            s_torch,
            s_triton,
            rtol=rtol,
            atol=atol,
            msg="Scales differ between torch and Triton implementations",
        )

    input_tensor = torch.randn(
        M,
        K,
        dtype=torch.bfloat16,
        device=device,
    )

    # Benchmark torch implementation
    torch_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs)
    y_torch, s_torch = torch_impl_c(input_tensor, block_size)
    torch_time_us = benchmark_cuda_function_in_microseconds(
        torch_impl_c,
        input_tensor,
        block_size,
    )

    # Benchmark Triton implementation
    y_triton, s_triton = triton_fp8_blockwise_act_quant_lhs(input_tensor, block_size)
    triton_time_us = benchmark_cuda_function_in_microseconds(
        triton_fp8_blockwise_act_quant_lhs,
        input_tensor,
        block_size,
    )

    # Verify correctness (optional, can comment out for pure benchmarking)
    verify_outputs(y_torch, s_torch, y_triton, s_triton)

    # Memory bandwidth calculations
    bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8
    bytes_per_output_el = torch.finfo(y_triton.dtype).bits / 8
    bytes_per_scale_el = torch.finfo(s_triton.dtype).bits / 8

    read_bytes = input_tensor.numel() * bytes_per_input_el
    write_bytes = (
        y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el
    )

    torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6)
    triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)

    return ExperimentResult(
        torch_us=torch_time_us,
        triton_us=triton_time_us,
        torch_gbps=torch_gbps,
        triton_gbps=triton_gbps,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "input_shape (M, K)",
        "block_size",
        "torch_us",
        "triton_us",
        "speedup",
        "torch_gbps",
        "triton_gbps",
    ]
    rows = []
    for experiment in experiments:
        speedup = experiment.result.torch_us / experiment.result.triton_us
        rows.append(
            [
                f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}",
                experiment.config.block_size,
                f"{experiment.result.torch_us:.2f}",
                f"{experiment.result.triton_us:.2f}",
                f"{speedup:.2f}x",
                f"{experiment.result.torch_gbps:.1f}",
                f"{experiment.result.triton_gbps:.1f}",
            ]
        )
    print(tabulate(rows, headers=headers, tablefmt="grid"))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []

    print(f"Running {len(configs)} benchmark configurations...\n")

    for config in tqdm(configs, desc="Benchmarking"):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    print("\n" + "=" * 80)
    print("BENCHMARK RESULTS")
    print("=" * 80 + "\n")
    print_results(results)


if __name__ == "__main__":
    main()
